#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# context.R: SparkContext driven functions

# Resolve the minimum number of partitions for a file-reading call.
# A caller-supplied value wins; otherwise fall back to
# min(defaultParallelism, 2) as reported by the JVM SparkContext.
# The result is always coerced to integer for the JVM call.
getMinPartitions <- function(sc, minPartitions) {
  resolved <- if (is.null(minPartitions)) {
    min(callJMethod(sc, "defaultParallelism"), 2)
  } else {
    minPartitions
  }
  as.integer(resolved)
}

#' Create an RDD from a text file.
#'
#' This function reads a text file from HDFS, a local file system (available on all
#' nodes), or any Hadoop-supported file system URI, and creates an
#' RDD of strings from it. The text files must be encoded as UTF-8.
#'
#' @param sc SparkContext to use
#' @param path Path of file to read. A vector of multiple paths is allowed.
#' @param minPartitions Minimum number of partitions to be created. If NULL, the default
#'  value is chosen based on available parallelism.
#' @return RDD where each item is of type \code{character}
#' @noRd
#' @examples
#'\dontrun{
#'  sc <- sparkR.init()
#'  lines <- textFile(sc, "myfile.txt")
#'}
textFile <- function(sc, path, minPartitions = NULL) {
  # normalizePath() expands "~" and relative local paths but only warns on
  # paths it cannot resolve (e.g. HDFS URIs), so those warnings are silenced.
  # Several paths are joined into the single comma-separated string the JVM
  # method expects.
  joinedPaths <- paste(suppressWarnings(normalizePath(path)), collapse = ",")
  numPartitions <- getMinPartitions(sc, minPartitions)
  # The JVM side returns a JavaRDD[String].
  RDD(callJMethod(sc, "textFile", joinedPaths, numPartitions), "string")
}

#' Load an RDD saved as a SequenceFile containing serialized objects.
#'
#' The file to be loaded should be one that was previously generated by calling
#' saveAsObjectFile() of the RDD class.
#'
#' @param sc SparkContext to use
#' @param path Path of file to read. A vector of multiple paths is allowed.
#' @param minPartitions Minimum number of partitions to be created. If NULL, the default
#'  value is chosen based on available parallelism.
#' @return RDD containing serialized R objects.
#' @seealso saveAsObjectFile
#' @noRd
#' @examples
#'\dontrun{
#'  sc <- sparkR.init()
#'  rdd <- objectFile(sc, "myfile")
#'}
objectFile <- function(sc, path, minPartitions = NULL) {
  # Same path handling as textFile(): expand local paths where possible
  # (normalizePath() only warns on unresolvable ones), then join multiple
  # paths with commas for the JVM call.
  normalized <- suppressWarnings(normalizePath(path))
  commaSeparated <- paste(normalized, collapse = ",")
  jrdd <- callJMethod(sc, "objectFile", commaSeparated,
                      getMinPartitions(sc, minPartitions))
  # The partitions are assumed to hold serialized R objects, hence "byte".
  RDD(jrdd, "byte")
}

# Compute, for each of `length` rows, the id of the slice it belongs to.
# This reimplements positions() in Spark's ParallelCollectionRDD. Slice i
# (0-based) covers rows [trunc(i * length / n), trunc((i + 1) * length / n)),
# and every row in it is tagged with that starting offset.
# For instance, numSerializedSlices = 22 and length = 50 gives
#  [1]  0  0  2  2  4  4  6  6  6  9  9 11 11 13 13 15 15 15 18 18 20 20 22 22 22
# [26] 25 25 27 27 29 29 31 31 31 34 34 36 36 38 38 40 40 40 43 43 45 45 47 47 47
# The slices holding three rows (starting at 6, 15, 22, ...) are spread
# roughly evenly.
# A non-positive slice count yields the single id 1.
makeSplits <- function(numSerializedSlices, length) {
  if (numSerializedSlices <= 0) {
    return(1)
  }
  sliceIdx <- as.numeric(0:(numSerializedSlices - 1))
  # nolint start
  starts <- trunc((sliceIdx * length) / numSerializedSlices)
  ends <- trunc(((sliceIdx + 1) * length) / numSerializedSlices)
  # nolint end
  # Repeat each slice's start offset once per row that the slice covers.
  rep(starts, times = ends - starts)
}

#' Create an RDD from a homogeneous list or vector.
#'
#' This function creates an RDD from a local homogeneous list in R. The elements
#' in the list are split into \code{numSlices} slices and distributed to nodes
#' in the cluster.
#'
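#' If the size of the collection (as estimated by object.size()) is larger than
#' spark.r.maxAllocationLimit (default: about 200MB), the serialized slices are written to disk
#' (or streamed over a local socket when encryption is enabled) instead of being sent over the
#' RPC channel. To make sure each slice is not larger than that limit, the number of slices may
#' also be increased.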
#'
#' In 2.2.0 we changed how numSlices is used/computed to better handle
#' 1 < (length(coll) / numSlices) << length(coll), and to get the exact number of slices.
#' This change affects both createDataFrame and spark.lapply.
#' When this function is used to convert an R native object into a SparkDataFrame, numSlices
#' has always been kept at the default of 1. When the object is large, the parallelism is
#' explicitly set to numSlices (which is still 1).
#'
#' Specifically, we are changing to split positions to match the calculation in positions() of
#' ParallelCollectionRDD in Spark.
#'
#' @param sc SparkContext to use
#' @param coll collection to parallelize
#' @param numSlices number of partitions to create in the RDD
#' @return an RDD created from this collection
#' @noRd
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, 1:10, 2)
#' # The RDD should contain 10 elements
#' length(rdd)
#'}
parallelize <- function(sc, coll, numSlices = 1) {
  # TODO: bound/safeguard numSlices
  # TODO: unit tests for if the split works for all primitives
  # TODO: support matrix, data frame, etc

  # Note, for data.frame, createDataFrame turns it into a list before it calls here.
  # nolint start
  # suppress lintr warning: Place a space before left parenthesis, except in a function call.
  if ((!is.list(coll) && !is.vector(coll)) || is.data.frame(coll)) {
  # nolint end
    if (is.data.frame(coll)) {
      message("context.R: A data frame is parallelized by columns.")
    } else if (is.matrix(coll)) {
      message("context.R: A matrix is parallelized by elements.")
    } else {
      message(paste("context.R: parallelize() currently only supports lists and vectors.",
                    "Calling as.list() to coerce coll into a list."))
    }
    coll <- as.list(coll)
  }

  sizeLimit <- getMaxAllocationLimit(sc)
  objectSize <- object.size(coll)
  len <- length(coll)

  # For large objects, also make sure each slice stays below sizeLimit.
  numSerializedSlices <- min(len, max(numSlices, ceiling(objectSize / sizeLimit)))

  slices <- split(coll, makeSplits(numSerializedSlices, len))

  # Serialize each slice: a list of raw vectors, one per slice.
  serializedSlices <- lapply(slices, serialize, connection = NULL)

  # The RPC backend cannot handle arguments larger than 2GB (INT_MAX).
  # If the data is safely below that threshold, send it over the RPC channel.
  # Otherwise, stream it over a local socket (when encryption is enabled), or
  # write it to a temp file and send the file name.
  if (objectSize < sizeLimit) {
    jrdd <- callJStatic("org.apache.spark.api.r.RRDD", "createRDDFromArray", sc, serializedSlices)
  } else {
    if (callJStatic("org.apache.spark.api.r.RUtils", "isEncryptionEnabled", sc)) {
      connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
      # numSlices (not numSerializedSlices) is the parallelism to use in the JVM's
      # sc.parallelize().
      parallelism <- as.integer(numSlices)
      jserver <- newJObject("org.apache.spark.api.r.RParallelizeServer", sc, parallelism)
      authSecret <- callJMethod(jserver, "secret")
      port <- callJMethod(jserver, "port")
      conn <- socketConnection(
        port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
      # writeToConnection() closes conn once it takes over. If the handshake
      # fails before that, close the socket here so it is not leaked, then
      # re-signal the original error.
      tryCatch(doServerAuth(conn, authSecret), error = function(e) {
        close(conn)
        stop(e)
      })
      writeToConnection(serializedSlices, conn)
      jrdd <- callJMethod(jserver, "getResult")
    } else {
      fileName <- writeToTempFile(serializedSlices)
      jrdd <- tryCatch(callJStatic(
          "org.apache.spark.api.r.RRDD", "createRDDFromFile", sc, fileName, as.integer(numSlices)),
        finally = {
          file.remove(fileName)
      })
    }
  }

  RDD(jrdd, "byte")
}

# Read spark.r.maxAllocationLimit from the SparkContext's configuration as a
# number. parallelize() compares it against object.size(), i.e. bytes.
# When the key is unset, default to integer.max / 10 (roughly 200MB), a safe
# value.
getMaxAllocationLimit <- function(sc) {
  sparkConf <- callJMethod(sc, "getConf")
  fallback <- toString(.Machine$integer.max / 10)
  configured <- callJMethod(sparkConf, "get", "spark.r.maxAllocationLimit", fallback)
  as.numeric(configured)
}

# Write each serialized slice to `conn` and then close the connection.
# Each slice is framed as a 4-byte big-endian length prefix followed by its
# bytes. The connection is closed even if a write fails part-way.
writeToConnection <- function(serializedSlices, conn) {
  on.exit(close(conn))
  for (payload in serializedSlices) {
    writeBin(as.integer(length(payload)), conn, endian = "big")
    writeBin(payload, conn, endian = "big")
  }
}

# Write the serialized slices to a fresh temporary file, using the same framing
# as writeToConnection(), and return the file's path. The caller is responsible
# for removing the file once the JVM has consumed it.
writeToTempFile <- function(serializedSlices) {
  fileName <- tempfile()
  conn <- file(fileName, "wb")
  # writeToConnection() always closes conn. If writing fails, also delete the
  # partially written file so it does not accumulate in tempdir(), then
  # re-signal the original error.
  tryCatch(
    writeToConnection(serializedSlices, conn),
    error = function(e) {
      unlink(fileName)
      stop(e)
    }
  )
  fileName
}

#' Include this specified package on all workers
#'
#' This function can be used to include a package on all workers before the
#' user's code is executed. This is useful in scenarios where other R package
#' functions are used in a function passed to functions like \code{lapply}.
#' NOTE: The package is assumed to be installed on every node in the Spark
#' cluster.
#'
#' @param sc SparkContext to use
#' @param pkg Package name
#' @noRd
#' @examples
#'\dontrun{
#'  library(Matrix)
#'
#'  sc <- sparkR.init()
#'  # Include the matrix library we will be using
#'  includePackage(sc, Matrix)
#'
#'  generateSparse <- function(x) {
#'    sparseMatrix(i=c(1, 2, 3), j=c(1, 2, 3), x=c(1, 2, 3))
#'  }
#'
#'  rdd <- lapplyPartition(parallelize(sc, 1:2, 2L), generateSparse)
#'  collect(rdd)
#'}
includePackage <- function(sc, pkg) {
  # `pkg` is captured unevaluated, so includePackage(sc, Matrix) and
  # includePackage(sc, "Matrix") both record "Matrix".
  pkg <- as.character(substitute(pkg))
  # `$` on an environment yields NULL for an unset binding, so no existence
  # check is needed. An exists() call without inherits = FALSE would also match
  # base::.packages through the enclosing environments.
  packages <- c(.sparkREnv$.packages, pkg)
  .sparkREnv$.packages <- packages
}

#' Broadcast a variable to all workers
#'
#' Broadcast a read-only variable to the cluster, returning a \code{Broadcast}
#' object for reading it in distributed functions.
#'
#' @param sc Spark Context to use
#' @param object Object to be broadcast
#' @noRd
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' rdd <- parallelize(sc, 1:2, 2L)
#'
#' # Large Matrix object that we want to broadcast
#' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000))
#' randomMatBr <- broadcastRDD(sc, randomMat)
#'
#' # Use the broadcast variable inside the function
#' useBroadcast <- function(x) {
#'   sum(value(randomMatBr) * x)
#' }
#' sumRDD <- lapply(rdd, useBroadcast)
#'}
broadcastRDD <- function(sc, object) {
  expr <- substitute(object)
  # Record the name the caller used for the object. For a bare symbol or
  # constant, keep as.character(). For a call such as `lst$mat`,
  # as.character() would split it into c("$", "lst", "mat"), so deparse it
  # into a single string instead.
  objName <- if (is.call(expr)) {
    paste(deparse(expr), collapse = "")
  } else {
    as.character(expr)
  }
  serializedObj <- serialize(object, connection = NULL)

  jBroadcast <- callJMethod(sc, "broadcast", serializedObj)
  id <- as.character(callJMethod(jBroadcast, "id"))

  Broadcast(id, object, jBroadcast, objName)
}

#' Set the checkpoint directory
#'
#' Set the directory under which RDDs are going to be checkpointed. The
#' directory must be an HDFS path if running on a cluster.
#'
#' @param sc Spark Context to use
#' @param dirName Directory path
#' @noRd
#' @examples
#'\dontrun{
#' sc <- sparkR.init()
#' setCheckpointDir(sc, "~/checkpoint")
#' rdd <- parallelize(sc, 1:2, 2L)
#' checkpoint(rdd)
#'}
setCheckpointDirSC <- function(sc, dirName) {
  # normalizePath() expands "~" and relative local paths; it only warns on
  # paths it cannot resolve (e.g. HDFS URIs), so those warnings are silenced.
  invisible(
    callJMethod(sc, "setCheckpointDir",
                suppressWarnings(normalizePath(dirName)))
  )
}

#' Add a file or directory to be downloaded with this Spark job on every node.
#'
#' The path passed can be either a local file, a file in HDFS (or other Hadoop-supported
#' filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
#' use spark.getSparkFiles(fileName) to find its download location.
#'
#' A directory can be given if the recursive option is set to true.
#' Currently directories are only supported for Hadoop-supported filesystems.
#' For the list of Hadoop-supported filesystems, refer to
#' \url{https://cwiki.apache.org/confluence/display/HADOOP2/HCFS}.
#'
#' Note: A path can be added only once. Subsequent additions of the same path are ignored.
#'
#' @rdname spark.addFile
#' @param path The path of the file to be added
#' @param recursive Whether to add files recursively from the path. Default is FALSE.
#' @examples
#'\dontrun{
#' spark.addFile("~/myfile")
#'}
#' @note spark.addFile since 2.1.0
spark.addFile <- function(path, recursive = FALSE) {
  sc <- getSparkContext()
  # Local paths are normalized (e.g. "~" expanded); normalizePath() only warns
  # for paths it cannot resolve, such as remote URIs, so warnings are silenced.
  invisible(
    callJMethod(sc, "addFile",
                suppressWarnings(normalizePath(path)),
                recursive)
  )
}

#' Get the root directory that contains files added through spark.addFile.
#'
#' @rdname spark.getSparkFilesRootDirectory
#' @return the root directory that contains files added through spark.addFile
#' @examples
#'\dontrun{
#' spark.getSparkFilesRootDirectory()
#'}
#' @note spark.getSparkFilesRootDirectory since 2.1.0
spark.getSparkFilesRootDirectory <- function() { # nolint
  onWorker <- Sys.getenv("SPARKR_IS_RUNNING_ON_WORKER") != ""
  if (onWorker) {
    # Worker: the root directory is provided through the environment.
    return(Sys.getenv("SPARKR_SPARKFILES_ROOT_DIR"))
  }
  # Driver: ask the JVM's SparkFiles.
  callJStatic("org.apache.spark.SparkFiles", "getRootDirectory")
}

#' Get the absolute path of a file added through spark.addFile.
#'
#' @rdname spark.getSparkFiles
#' @param fileName The name of the file added through spark.addFile
#' @return the absolute path of a file added through spark.addFile.
#' @examples
#'\dontrun{
#' spark.getSparkFiles("myfile")
#'}
#' @note spark.getSparkFiles since 2.1.0
spark.getSparkFiles <- function(fileName) {
  onWorker <- Sys.getenv("SPARKR_IS_RUNNING_ON_WORKER") != ""
  if (!onWorker) {
    # Driver: let the JVM's SparkFiles resolve the absolute path.
    return(callJStatic("org.apache.spark.SparkFiles", "get", as.character(fileName)))
  }
  # Worker: join the name onto the SparkFiles root directory.
  file.path(spark.getSparkFilesRootDirectory(), as.character(fileName))
}

#' Run a function over a list of elements, distributing the computations with Spark
#'
#' Run a function over a list of elements, distributing the computations with Spark. Applies a
#' function in a manner that is similar to doParallel or lapply to elements of a list.
#' The computations are distributed using Spark. It is conceptually the same as the following code:
#'   lapply(list, func)
#'
#' Known limitations:
#' \itemize{
#'    \item variable scoping and capture: compared to R's rich support for variable resolutions,
#'    the distributed nature of SparkR limits how variables are resolved at runtime. All the
#'    variables that are available through lexical scoping are embedded in the closure of the
#'    function and available as read-only variables within the function. The environment variables
#'    should be stored into temporary variables outside the function, and not directly accessed
#'    within the function.
#'
#'   \item loading external packages: In order to use a package, you need to load it inside the
#'   closure. For example, if you rely on the MASS package, here is how you would use it:
#'   \preformatted{
#'     train <- function(hyperparam) {
#'       library(MASS)
#'       model <- lm.ridge(y ~ x + z, data, lambda = hyperparam)
#'       model
#'     }
#'   }
#' }
#'
#' @rdname spark.lapply
#' @param list the list of elements
#' @param func a function that takes one argument.
#' @return a list of results (the exact type being determined by the function)
#' @examples
#'\dontrun{
#' sparkR.session()
#' doubled <- spark.lapply(1:10, function(x) {2 * x})
#'}
#' @note spark.lapply since 2.0.0
spark.lapply <- function(list, func) {
  sc <- getSparkContext()
  # Request one slice per element, apply func remotely, and bring the
  # results back to the driver as a local list.
  distributed <- parallelize(sc, list, length(list))
  collectRDD(map(distributed, func))
}

#' Set new log level
#'
#' Set new log level: "ALL", "DEBUG", "ERROR", "FATAL", "INFO", "OFF", "TRACE", "WARN"
#'
#' @rdname setLogLevel
#' @param level New log level
#' @examples
#'\dontrun{
#' setLogLevel("ERROR")
#'}
#' @note setLogLevel since 2.0.0
setLogLevel <- function(level) {
  # Forward the level string unchanged to the JVM SparkContext; the result is
  # returned invisibly.
  invisible(callJMethod(getSparkContext(), "setLogLevel", level))
}

#' Set checkpoint directory
#'
#' Set the directory under which SparkDataFrame are going to be checkpointed. The directory must be
#' an HDFS path if running on a cluster.
#'
#' @rdname setCheckpointDir
#' @param directory Directory path to checkpoint to
#' @seealso \link{checkpoint}
#' @examples
#'\dontrun{
#' setCheckpointDir("/checkpoint")
#'}
#' @note setCheckpointDir since 2.2.0
setCheckpointDir <- function(directory) {
  sc <- getSparkContext()
  # normalizePath() expands "~" and relative local paths; it only warns on
  # paths it cannot resolve (e.g. HDFS URIs), so those warnings are silenced.
  invisible(
    callJMethod(sc, "setCheckpointDir",
                suppressWarnings(normalizePath(directory)))
  )
}
